// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================
//
// Distributed XLA service protocol.
//
// This is a minimal distributed protocol intended for a small set of purposes:
// * barriers to wait for all clients to start up or shut down,
// * health checking to detect when clients vanish, and
// * sharing GPU topology and NCCL communicator state between distributed
//   hosts.
//
// The intention is that a service is started during cluster initialization and
// persists for the lifetime of the cluster.

// This schema uses the proto3 language version.
syntax = "proto3";

// Every message and the service declared in this file live in package `xla`.
package xla;

// Go package used for code generated from this file. The value is written as
// two adjacent string literals, which the protobuf compiler concatenates into
// a single string.
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/compiler/"
                    "xla/pjrt/distributed/protocol_go_proto";

// Describes a single device that is local to a host. Instances are carried in
// the `devices` list of a LocalTopologyProto.
message DeviceProto {
  // Ordinal of the device on its own host.
  // NOTE(review): the numbering scheme (e.g. whether it matches a driver-level
  // device index) is not specified in this file -- confirm.
  int32 local_device_ordinal = 1;
  // Name of the device.
  // NOTE(review): expected format is not specified in this file -- confirm.
  string name = 2;
  // Vendor of the device.
  // NOTE(review): expected format is not specified in this file -- confirm.
  string vendor = 3;

  // The two fields below are only populated in the GlobalTopologyProto message
  // returned by EnumerateDevices(); they are not set in the LocalTopologyProto
  // messages that clients pass to EnumerateDevices(). That is, the coordinator
  // node is the one that determines the global device IDs, and it does so
  // during EnumerateDevices().
  int32 global_device_id = 4;  // Globally unique ID number of this device.
  // Devices that share the same slice_index are connected by a fast network,
  // e.g. NVLink on GPUs.
  int32 slice_index = 5;
}

// The device topology of one node. Each client sends its own local topology
// in EnumerateDevicesRequest; GlobalTopologyProto holds a list of these.
message LocalTopologyProto {
  // ID of the node that owns the devices listed below.
  // NOTE(review): presumably the same node ID the node supplies in
  // ConnectRequest.node_id -- confirm.
  int32 node_id = 1;
  // Value that is unique per OS kernel restart, used to uniquely identify a
  // host. See /proc/sys/kernel/random/boot_id.
  string boot_id = 2;
  // The devices that are local to this node.
  repeated DeviceProto devices = 3;
}

// The combined device topology, expressed as a list of per-node local
// topologies. Returned to clients in EnumerateDevicesResponse.
message GlobalTopologyProto {
  // Local topologies of the nodes.
  // NOTE(review): whether the list is ordered by node ID is not stated in this
  // file -- confirm before relying on element order.
  repeated LocalTopologyProto nodes = 1;
}

// Request message for DistributedRuntimeService.Connect.
message ConnectRequest {
  // NOTE(review): presumably the version of this protocol that the client
  // speaks; how a version mismatch is handled is not visible in this file.
  int32 protocol_version = 1;
  // Timeout, in milliseconds.
  // NOTE(review): presumably bounds how long Connect may block while waiting
  // for the other nodes to connect -- confirm.
  int32 timeout_milliseconds = 2;

  // We assume that each node knows its globally-unique node ID, provided by
  // whatever mechanism launches the tasks. Node IDs should form a dense range
  // of integers [0, num_nodes).
  int32 node_id = 3;

  // A unique ID number for the client.
  // NOTE(review): how this ID is generated and what it is compared against is
  // not specified in this file.
  uint64 client_id = 4;
}

// Response message for DistributedRuntimeService.Connect.
message ConnectResponse {
  // ID of the session. Every other request message in this file carries a
  // `session_id` field.
  // NOTE(review): presumably clients copy this value into those fields --
  // confirm.
  uint64 session_id = 1;
}

// Request message for DistributedRuntimeService.EnumerateDevices. Each client
// sends its own local device topology in this message.
message EnumerateDevicesRequest {
  // Field number 2 is skipped by the fields below. It is reserved so that a
  // future field cannot claim a number that an earlier revision of this
  // message may already have used on the wire.
  // NOTE(review): confirm against the file history whether 2 was ever in use.
  reserved 2;

  // ID of the session.
  // NOTE(review): presumably the value returned in ConnectResponse -- confirm.
  uint64 session_id = 1;
  // Device topology local to the calling client.
  LocalTopologyProto local_topology = 3;
}

// Response message for DistributedRuntimeService.EnumerateDevices.
message EnumerateDevicesResponse {
  // The global topology that clients receive back after submitting their
  // local topologies.
  GlobalTopologyProto global_topology = 1;
}

// Request message for DistributedRuntimeService.KeyValueGet.
message KeyValueGetRequest {
  // ID of the session.
  // NOTE(review): presumably the value returned in ConnectResponse -- confirm.
  uint64 session_id = 1;
  // Key to look up in the key-value store.
  bytes key = 2;
  // Timeout, in milliseconds.
  // NOTE(review): presumably the limit on how long KeyValueGet blocks while
  // waiting for the key to become present -- confirm.
  int32 timeout_milliseconds = 3;
}

// Response message for DistributedRuntimeService.KeyValueGet.
message KeyValueGetResponse {
  // NOTE(review): presumably true when the key was present and `value` is
  // populated; the conditions under which it is false (e.g. a timeout) are not
  // visible in this file -- confirm.
  bool found = 1;
  // Value associated with the requested key.
  bytes value = 2;
}

// Request message for DistributedRuntimeService.KeyValueSet.
message KeyValueSetRequest {
  // ID of the session.
  // NOTE(review): presumably the value returned in ConnectResponse -- confirm.
  uint64 session_id = 1;
  // Key whose associated value is to be updated.
  bytes key = 2;
  // Value to associate with `key`.
  bytes value = 3;
}

// Response message for DistributedRuntimeService.KeyValueSet. It currently
// declares no fields.
message KeyValueSetResponse {}

// Request message for DistributedRuntimeService.WaitAtBarrier.
message WaitAtBarrierRequest {
  // ID of the session.
  // NOTE(review): presumably the value returned in ConnectResponse -- confirm.
  uint64 session_id = 1;
  // Identifies the barrier to wait at.
  // NOTE(review): presumably all nodes meeting at the same barrier pass the
  // same value -- confirm.
  bytes barrier_id = 2;
  // NOTE(review): presumably the ID of the node that is waiting -- confirm.
  int32 node_id = 3;
  // Timeout, in milliseconds.
  // NOTE(review): presumably the time after which the barrier times out --
  // confirm.
  int32 timeout_milliseconds = 4;
}

// Response message for DistributedRuntimeService.WaitAtBarrier. It currently
// declares no fields.
message WaitAtBarrierResponse {}

// Request message for DistributedRuntimeService.Heartbeat.
message HeartbeatRequest {
  // ID of the session.
  // NOTE(review): presumably the value returned in ConnectResponse -- confirm.
  uint64 session_id = 1;
  // NOTE(review): presumably the ID of the node sending the heartbeat --
  // confirm.
  int32 node_id = 2;
}
// Response message for DistributedRuntimeService.Heartbeat. It currently
// declares no fields.
message HeartbeatResponse {}

// Request message for DistributedRuntimeService.Shutdown.
message ShutdownRequest {
  // ID of the session.
  // NOTE(review): presumably the value returned in ConnectResponse -- confirm.
  uint64 session_id = 1;
  // NOTE(review): presumably the ID of the node that is ready to shut down --
  // confirm.
  int32 node_id = 2;
}
// Response message for DistributedRuntimeService.Shutdown. It currently
// declares no fields.
message ShutdownResponse {}

// RPCs that nodes use to coordinate through the distributed coordinator node:
// connecting, device enumeration, health checking, shutdown, a key-value
// store, and barriers.
service DistributedRuntimeService {
  // Connects a node to the distributed coordinator node. Blocks until all
  // tasks have connected. The number of nodes to expect is not part of this
  // protocol; the service receives it as an option passed to its constructor.
  rpc Connect(ConnectRequest) returns (ConnectResponse) {}

  // Blocking enumeration of devices, used by the GPU backend only.
  // All clients call EnumerateDevices() in parallel, each passing its local
  // device topology, and each receives the global topology in response.
  rpc EnumerateDevices(EnumerateDevicesRequest)
      returns (EnumerateDevicesResponse) {}

  // Health-checking RPC. Workers send heartbeats to the coordinator at regular
  // intervals. If a worker does not hear from the coordinator, or the
  // coordinator does not hear from the workers, the tasks abort.
  rpc Heartbeat(HeartbeatRequest) returns (HeartbeatResponse) {}

  // Shutdown RPC. Workers send this RPC when they are ready to shut down; the
  // RPC blocks until all tasks have indicated they are ready to shut down,
  // or a timeout is reached.
  // NOTE(review): ShutdownRequest has no timeout field, so the timeout is
  // presumably configured outside this protocol -- confirm.
  rpc Shutdown(ShutdownRequest) returns (ShutdownResponse) {}

  // The two RPCs below (KeyValueGet and KeyValueSet) form a simple key-value
  // store used for sharing configuration data. For example, when using NCCL
  // to communicate between multiple GPUs, the NCCL communicator IDs are
  // stored here.

  // Looks up a key in the key-value service. Blocks until the key is present
  // or until the request's `timeout_milliseconds` expires.
  rpc KeyValueGet(KeyValueGetRequest) returns (KeyValueGetResponse) {}

  // Updates the value associated with a key.
  rpc KeyValueSet(KeyValueSetRequest) returns (KeyValueSetResponse) {}

  // Blocks until all nodes are at the barrier or the barrier times out.
  rpc WaitAtBarrier(WaitAtBarrierRequest) returns (WaitAtBarrierResponse) {}
}
